{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import h2o4gpu\n",
    "import h2o4gpu.util.import_data as io\n",
    "import h2o4gpu.util.metrics as metrics\n",
    "from tabulate import tabulate\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reading Data with Pandas\n",
      "(999, 9733)\n",
      "Original m=999 n=9732\n",
      "Size of Train rows=800 & valid rows=199\n",
      "Size of Train cols=9732 valid cols=9732\n",
      "Size of Train cols=9733 & valid cols=9733 after adding intercept column\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Import Data for H2O GPU Edition\n",
    "\n",
    "This function will read in data and prepare it for H2O4GPU's GLM solver\n",
    "\n",
    "Parameters\n",
    "----------\n",
    "data_path : str\n",
    "             A path to a dataset (The dataset needs to be all numeric)\n",
    "use_pandas : bool\n",
    "              Indicate if Pandas should be used to parse\n",
    "intercept : bool\n",
    "              Indicate if intercept term is needed\n",
    "valid_fraction : float\n",
    "                  Percentage of dataset reserved for a validation set\n",
    "classification : bool\n",
    "                  Classification problem?\n",
    "Returns\n",
    "-------\n",
    "If valid_fraction > 0 it will return the following:\n",
    "    train_x: numpy array of train input variables\n",
    "    train_y: numpy array of y variable\n",
    "    valid_x: numpy array of valid input variables\n",
    "    valid_y: numpy array of valid y variable\n",
    "    family : string that would either be \"logistic\" if classification is set to True, otherwise \"elasticnet\"\n",
    "If valid_fraction == 0 it will return the following:\n",
    "    train_x: numpy array of train input variables\n",
    "    train_y: numpy array of y variable\n",
    "    family : string that would either be \"logistic\" if classification is set to True, otherwise \"elasticnet\"\n",
    "\"\"\"\n",
    "import os\n",
    "if not os.path.exists(\"ipums_1k.csv\"):\n",
    "    !wget https://s3.amazonaws.com/h2o-public-test-data/h2o4gpu/open_data/ipums_1k.csv\n",
    "train_x,train_y,valid_x,valid_y,family=io.import_data(data_path=\"ipums_1k.csv\", \n",
    "                                                        use_pandas=True, \n",
    "                                                        intercept=True,\n",
    "                                                        valid_fraction=0.2,\n",
    "                                                        classification=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Setting up solver\n",
      "Running h2o4gpu Lasso Regression\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Set up instance of H2O4GPU's Lasso solver with default parameters\n",
    "\n",
    "Need to pass in `family` to indicate problem type to solve\n",
    "\"\"\"\n",
    "print(\"Setting up solver\")\n",
    "model = h2o4gpu.Lasso()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Solving\n",
      "CPU times: user 17.6 s, sys: 4.13 s, total: 21.8 s\n",
      "Wall time: 21.8 s\n",
      "Done Solving\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Fit Lasso Solver\n",
    "\"\"\"\n",
    "print(\"Solving\")\n",
    "%time model.fit(train_x, train_y)\n",
    "print(\"Done Solving\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[  1.09889756e+04   3.75698984e+04   5.33906016e+04   3.94701289e+04\n",
      "    1.99471836e+04   3.26172285e+04   1.46580322e+04   3.38587031e+04\n",
      "    2.57400723e+04   5.38813477e+04  -4.55054321e+02   6.01770312e+03\n",
      "    1.17396279e+04   1.83504414e+04   5.06654219e+04   4.08623359e+04\n",
      "    1.75025195e+04   1.76822148e+04   1.50483564e+04   1.10771023e+05\n",
      "    4.53170215e+03   2.32612793e+04   1.49927441e+04   2.93769385e+03\n",
      "    7.86341357e+03   3.90922188e+04   4.21069492e+04   7.68087578e+04\n",
      "    2.41864727e+04   8.20444031e+01   2.97903848e+04   2.61350742e+04\n",
      "    1.56964771e+03   1.34936582e+04  -5.14499939e+02  -3.15045837e+02\n",
      "    1.40303164e+04   1.81600508e+04   1.19846484e+04   7.00190156e+04\n",
      "   -3.97029370e+03  -1.00706309e+04   4.59641484e+04   2.59598672e+04\n",
      "   -1.22267793e+04   1.15692803e+04   3.72865547e+04   3.06964609e+04\n",
      "    2.13391680e+04   3.15113086e+04   1.19520117e+04   2.00309551e+04\n",
      "    2.74033320e+04   2.17361055e+04  -2.11075391e+03   3.58349072e+03\n",
      "   -6.49273071e+02   1.48605615e+04   3.38002305e+04   2.22022305e+04\n",
      "    3.98001211e+04   3.56661401e+03   2.28747344e+04   1.03059199e+04\n",
      "    4.58545312e+04   2.24593750e+04   5.39659141e+04   6.01053164e+04\n",
      "    1.86068594e+04   8.31790391e+04   3.50255586e+04   4.23314819e+02\n",
      "   -5.24913184e+03   3.83369180e+04   2.56234473e+04   1.02244189e+04\n",
      "    4.17390312e+04   6.64631250e+03   3.87100977e+04   1.73529766e+04\n",
      "    3.44353203e+04   2.23016562e+04   3.71430117e+04   1.51529307e+04\n",
      "    4.03290352e+04   3.16899746e+04   7.45363125e+04   1.63988984e+04\n",
      "    2.35902109e+04   1.99843027e+04   1.37177920e+04   1.39650635e+04\n",
      "    5.57664531e+04   1.08431914e+05   7.31540625e+03   3.33538125e+04\n",
      "    1.25796172e+04   3.41086250e+04   2.52959414e+04   3.11856699e+04\n",
      "    2.22621816e+04   1.31161846e+04   2.44573574e+04   5.72742852e+04\n",
      "    2.12948457e+04   3.96020234e+04   6.48818125e+04   3.65094189e+03\n",
      "    1.69439343e+03   2.39442334e+03   6.70209814e+03   1.54435205e+04\n",
      "   -4.64320215e+03   2.76259785e+04   1.35587207e+04   1.87338613e+04\n",
      "    4.84235254e+03   4.61896387e+03   1.53843359e+04   1.26019600e+04\n",
      "    1.65664922e+04   3.80392227e+04   9.17164258e+03   1.50349844e+04\n",
      "    3.65161953e+04   2.12460723e+04   9.43408398e+03   3.14460859e+04\n",
      "    3.42797578e+04   4.89760273e+04   1.57843398e+04   3.87667725e+03\n",
      "   -3.99689746e+03   2.03973711e+04   1.04191766e+05  -5.37930859e+03\n",
      "    4.31113906e+04  -2.60748291e+03   7.35169531e+03   1.95001895e+04\n",
      "    3.67719297e+04   9.43949766e+04   5.54418438e+04   1.92502969e+04\n",
      "    2.56520469e+04   2.88755176e+04   2.03221973e+04   6.88230273e+03\n",
      "    3.11458184e+04   3.60932617e+04   6.45547227e+04   7.79952100e+03\n",
      "    1.27797969e+04   1.79521367e+04   1.62155215e+04   3.10070176e+04\n",
      "    2.23901426e+04   2.04088086e+04   7.22836641e+04   1.86945820e+04\n",
      "    1.14795869e+04   3.42041016e+04   1.19944141e+04   3.33064209e+03\n",
      "   -5.38102490e+03   2.81105566e+04   3.22572266e+04   5.58630078e+04\n",
      "    2.57800508e+04   2.96762812e+04   1.74843730e+04   6.84701484e+04\n",
      "    3.10810469e+04   2.62613711e+04   2.57703184e+04   3.12589141e+04\n",
      "    2.98682969e+04   6.25953242e+04   5.23314531e+04   1.26882588e+04\n",
      "    4.33346924e+03   2.35517266e+04   5.38530195e+04   6.67725391e+04\n",
      "    2.33688926e+04   2.94185742e+04   2.76960820e+04   1.56924434e+04\n",
      "    1.09514238e+04   1.46083848e+04   1.61235576e+04   1.58833721e+04\n",
      "    2.03798750e+04   1.90066660e+04   4.01609219e+04   1.53184971e+04\n",
      "    1.94606895e+04   7.82145215e+03   2.27939727e+04]]\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Make predictions on validation set\n",
    "\"\"\"\n",
    "preds = model.predict(valid_x)\n",
    "print(preds)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
